The JAX AI stack

Marie-Hélène Burle

What is JAX?

High-performance accelerator-oriented array computing library for Python developed by Google

Composition, JIT-compilation, transformation, and automatic differentiation of numerical programs

NumPy-like and lower-level APIs

Requires a functional programming style: transformed functions must be pure (no side effects)
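As a minimal sketch of what the functional constraint means in practice (the array values and the `scale` function are made up for illustration): JAX arrays are immutable, so updates return new arrays, and functions passed to `jax.jit` should depend only on their arguments.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# JAX arrays are immutable: in-place assignment such as x[0] = 10 raises an error.
# The functional alternative returns a new array and leaves x unchanged.
y = x.at[0].set(10)

@jax.jit
def scale(v, factor):
    # Pure function: the result depends only on the arguments
    return v * factor

z = scale(y, 2)
print(x, y, z)
```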

Why JAX?

Fast

  • Default data type suited for deep learning

    Like PyTorch, uses float32 as default. This level of precision is suitable for deep learning and increases efficiency (by contrast, NumPy defaults to float64)

  • JIT compilation

  • The same code can run on CPUs or on accelerators (GPUs and TPUs)

  • XLA (Accelerated Linear Algebra) optimization

  • Asynchronous dispatch

  • Vectorization, data parallelism, and sharding

    All levels of shared and distributed memory parallelism are supported
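A few of the points above can be checked in a short sketch. The `normalize` function and the input values are hypothetical, and the float32 default assumes the `jax_enable_x64` option has not been turned on.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Default floating-point types differ between the two libraries
jax_dtype = jnp.ones(3).dtype    # float32 (assuming jax_enable_x64 is off)
numpy_dtype = np.ones(3).dtype   # float64

@jax.jit
def normalize(v):
    # Compiled by XLA on the first call, then reused for inputs
    # of the same shape and dtype
    return (v - v.mean()) / v.std()

x = jnp.arange(6.0).reshape(2, 3)

# jax.vmap vectorizes normalize over the leading axis (one row at a time)
rows = jax.vmap(normalize)(x)

# Dispatch is asynchronous: block_until_ready waits for the computation
# to finish, which matters when timing code
rows.block_until_ready()
print(jax_dtype, numpy_dtype, rows.shape)
```

The same script runs unchanged on CPU, GPU, or TPU; JAX picks the available backend.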

Great automatic differentiation (AD)

  • TensorFlow
    Autodiff method: static graph and XLA
    Advantage: mostly optimized AD
    Disadvantage: manual writing of IR

  • PyTorch
    Autodiff method: dynamic graph
    Advantage: convenient
    Disadvantage: limited AD optimization

  • TensorFlow2
    Autodiff method: dynamic graph and XLA
    Advantage: convenient
    Disadvantage: disappointing speed

  • JAX
    Autodiff method: pseudo-dynamic and XLA
    Advantage: convenient and mostly optimized AD
    Disadvantage: pure functions only (subset of Python)

  Summarized from a blog post by Chris Rackauckas

Close to the math

Considering the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

Calling dfdx returns the derivative of f at a given point:

print(dfdx(1.))
4.0

Forward and reverse modes

  • reverse-mode vector-Jacobian products: jax.vjp
  • forward-mode Jacobian-vector products: jax.jvp
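Both modes can be tried on the same polynomial `f` used earlier. This is a small sketch, with the evaluation point 1.0 and the unit tangent and cotangent chosen for illustration; for a scalar function both modes give the plain derivative.

```python
import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Forward mode: push a tangent vector through f alongside the primal value
primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))

# Reverse mode: evaluate f, then pull a cotangent back through it
primal_out_rev, f_vjp = jax.vjp(f, 1.0)
(cotangent_in,) = f_vjp(1.0)

print(primal_out, tangent_out, cotangent_in)
```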

Higher-order differentiation

With a single variable, the grad function calls can be nested:

d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...

With several variables, nesting grad is no longer enough because the gradient is vector-valued. Compose the Jacobian functions instead:

  • jax.jacfwd for forward-mode,
  • jax.jacrev for reverse-mode.
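For example, a Hessian can be obtained by composing the two. The function `g` below is a hypothetical two-variable example, not one from the talk; forward-over-reverse is one common choice of composition.

```python
import jax
import jax.numpy as jnp

def g(v):
    # Hypothetical scalar function of two variables: g(x, y) = x**2 * y + y**3
    x, y = v
    return x**2 * y + y**3

v = jnp.array([1.0, 2.0])

gradient = jax.grad(g)(v)               # analytically [2xy, x**2 + 3y**2]
hessian = jax.jacfwd(jax.jacrev(g))(v)  # analytically [[2y, 2x], [2x, 6y]]

print(gradient)
print(hessian)
```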

How does it work?

[Diagram: the JAX pipeline]

  • Pure Python functions are traced
  • Tracing produces jaxprs (JAX expressions), the intermediate representation (IR)
  • Transformations act on jaxprs: vectorization (jax.vmap), parallelization (jax.pmap), differentiation (jax.grad)
  • Just-in-time (JIT) compilation (jax.jit) turns jaxprs into a High Level Operations (HLO) program
  • Accelerated Linear Algebra (XLA) compiles the HLO program for CPU, GPU, or TPU
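The tracing step can be inspected directly with `jax.make_jaxpr`. The sketch below reuses the polynomial from the earlier slides and then composes the three transformations named in the diagram; the input values are arbitrary.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**3 + 2*x**2 - 3*x + 8

# Tracing: record the operations f performs on an abstract input as a jaxpr
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Transformations compose: differentiate, vectorize, then JIT-compile
dfdx_batched = jax.jit(jax.vmap(jax.grad(f)))
values = dfdx_batched(jnp.array([0.0, 1.0, 2.0]))
print(values)
```

The printed jaxpr lists primitive operations on typed variables; the exact text varies between JAX versions.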

JAX for AI

JAX itself is not a deep learning library …

[Diagram: JAX at the centre of an ecosystem covering deep learning, optimizers, probabilistic programming, probabilistic modeling, LLMs, solvers, and physics simulations]

… but a Python sublanguage ideal for deep learning

The JAX AI stack

Modular approach